{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 12: Train an Encrypted NN on Encrypted Data\n",
    "\n",
    "In this notebook, we're going to use all the techniques we've learned thus far to perform neural network training (and prediction) while both the model and the data are encrypted.\n",
    "\n",
    "In particular, we present our custom Autograd engine which works on encrypted computations.\n",
    "\n",
    "Authors:\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "- Jason Paumier - GitHub: [@Jasopaum](https://github.com/Jasopaum)\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 1: Create Workers and Toy Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import syft as sy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set everything up\n",
    "hook = sy.TorchHook(torch)\n",
    "\n",
    "alice = sy.VirtualWorker(id=\"alice\", hook=hook)\n",
    "bob = sy.VirtualWorker(id=\"bob\", hook=hook)\n",
    "james = sy.VirtualWorker(id=\"james\", hook=hook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A Toy Dataset\n",
    "data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]])\n",
    "target = torch.tensor([[0],[0],[1],[1.]])\n",
    "\n",
    "# A Toy Model\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(2, 2)\n",
    "        self.fc2 = nn.Linear(2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "model = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 2: Encrypt the Model and Data\n",
    "\n",
    "Encryption here comes in two steps. Secure Multi-Party Computation only works on integers, so to operate on numbers with decimal points (such as weights and activations) we first encode every number using Fixed Precision, which keeps a fixed number of decimal digits. We do this by calling `.fix_precision()`.\n",
    "\n",
    "We can then call `.share()` as we have in other demos, which encrypts all of the values by splitting them into shares held by Alice and Bob. Note that we also set `requires_grad` to `True`, which adds a special autograd method for encrypted data. Since Secure Multi-Party Computation doesn't work on float values, we can't use the usual PyTorch autograd. Instead, we add a special AutogradTensor node that builds the gradient graph for backpropagation. You can print any of these elements to see that it includes an AutogradTensor."
   ]
  },
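  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the two steps described above (not PySyft's exact implementation; the symbols $b$, $f$, $Q$, $s_{\\text{alice}}$ and $s_{\\text{bob}}$ are notation introduced here for illustration, and the exact rounding mode is an implementation detail): with base $b = 10$ and $f = 3$ fractional digits, a value $x$ is first turned into an integer, which is then split into two additive shares modulo some large ring size $Q$:\n",
    "\n",
    "$$\\hat{x} = \\mathrm{round}(x \\cdot b^{f}), \\qquad \\hat{x} \\equiv s_{\\text{alice}} + s_{\\text{bob}} \\pmod{Q}$$\n",
    "\n",
    "For example, $x = 0.25$ is encoded as $\\hat{x} = 250$. If Alice's share is a random value $s_{\\text{alice}} = r$, then Bob holds $s_{\\text{bob}} = (250 - r) \\bmod Q$, and neither share alone reveals $x$. Decoding reverses both steps: add the shares modulo $Q$, then divide by $b^{f}$."
   ]
  },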
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We encode everything\n",
    "data = data.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
    "target = target.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n",
    "model = model.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 3: Train\n",
    "\n",
    "And now we can train using simple tensor logic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = optim.SGD(params=model.parameters(),lr=0.1).fix_precision()\n",
    "\n",
    "for epoch in range(20):\n",
    "    # 1) erase previous gradients (if they exist)\n",
    "    opt.zero_grad()\n",
    "\n",
    "    # 2) make a prediction\n",
    "    pred = model(data)\n",
    "\n",
    "    # 3) calculate how much we missed\n",
    "    loss = ((pred - target)**2).sum()\n",
    "\n",
    "    # 4) figure out which weights caused us to miss\n",
    "    loss.backward()\n",
    "\n",
    "    # 5) change those weights\n",
    "    opt.step()\n",
    "\n",
    "    # 6) print our progress\n",
    "    print(loss.get().float_precision())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss indeed decreased!"
   ]
  },
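  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In plain notation (a sketch; $\\theta$ for the model parameters, $\\eta$ for the learning rate and $f_\\theta$ for the network are symbols introduced here), each iteration of the loop above computes the summed squared error over the four samples and takes one gradient step with $\\eta = 0.1$:\n",
    "\n",
    "$$L(\\theta) = \\sum_{i=1}^{4} \\left(f_\\theta(x_i) - y_i\\right)^2, \\qquad \\theta \\leftarrow \\theta - \\eta \\, \\nabla_\\theta L(\\theta)$$\n",
    "\n",
    "The only difference from ordinary PyTorch training is that every quantity is a fixed-precision, secret-shared tensor, so the gradient is built by the AutogradTensor rather than by PyTorch's own autograd."
   ]
  },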
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Impact of fixed precision\n",
    "You might wonder how encrypting everything affects the loss curve. Because the underlying computation is the same, the numbers stay very close to those of non-encrypted training. You can verify this by running the same example without encryption and with a deterministic initialisation of the model, such as this one placed in the model's `__init__`:\n",
    "```\n",
    "with torch.no_grad():\n",
    "    self.fc1.weight.set_(torch.tensor([[ 0.0738, -0.2109],[-0.1579,  0.3174]], requires_grad=True))\n",
    "    self.fc1.bias.set_(torch.tensor([0.,0.1], requires_grad=True))\n",
    "    self.fc2.weight.set_(torch.tensor([[-0.5368,  0.7050]], requires_grad=True))\n",
    "    self.fc2.bias.set_(torch.tensor([-0.0343], requires_grad=True))\n",
    "```\n",
    "\n",
    "The slight difference you might observe comes from the rounding performed when values are converted to fixed precision. The default `precision_fractional` is 3. Lowering it to 2 increases the divergence from clear-text training, while raising it to 4 reduces it."
   ]
  },
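  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of `precision_fractional` concrete, here is a back-of-the-envelope bound (with $f$ standing for `precision_fractional`, $\\tilde{x}$ for the decoded value, and base 10 assumed; both symbols are introduced here for illustration). Encoding a value $x$ as an integer multiple of $10^{-f}$ changes it by less than one unit in the last kept digit:\n",
    "\n",
    "$$|x - \\tilde{x}| < 10^{-f}$$\n",
    "\n",
    "So each encoded value is off by less than $10^{-2}$ for $f = 2$, $10^{-3}$ for the default $f = 3$, and $10^{-4}$ for $f = 4$. These small errors accumulate across the multiplications and the 20 training steps, which is why the loss values drift slightly from clear-text training."
   ]
  },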
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top-level tickets, giving an overview of the projects you can join! If you don't want to join a project but would like to do a bit of coding, you can also look for \"one-off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
